# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Automatically generated by addcopyright.py at 01/29/2013
'''
Created on Jan 2, 2013

@author: frank
'''
import cherrypy
import sglib
import xmlobject
import types
import uuid
import os.path
import sys
import os

class SGRule(object):
    """A single security-group rule parsed from a set_rules request.

    Attributes are filled in after construction by the request parser:
    ``protocol``, ``start_port`` and ``end_port`` hold the element text, and
    ``allowed_ips`` collects the permitted source/destination addresses.
    """

    def __init__(self):
        # Fresh list per instance so rules never share their address lists.
        self.allowed_ips = list()
        self.protocol = self.start_port = self.end_port = None

class IPSet(object):
    """A named ipset of type ``hash:ip`` holding a given list of addresses."""
    IPSET_TYPE = 'hash:ip'

    def __init__(self, setname, ips):
        self.ips = ips
        self.name = setname

    def create(self):
        """Create ipset ``self.name`` if needed and replace its members with ``self.ips``.

        Members are loaded into a temporary set, which is then swapped with
        the target set (``ipset -W``). The target is only swapped once every
        address was added successfully. The temporary set is always flushed
        and destroyed, even when adding or swapping fails. Any ShellCmd
        failure propagates to the caller.
        """
        # 30 hex chars keeps the name short. NOTE(review): ipset limits set
        # names to 31 chars on common versions — confirm for the deployed one.
        tmpname = str(uuid.uuid4()).replace('-', '')[0:30]
        sglib.ShellCmd('ipset -N %s %s' % (tmpname, self.IPSET_TYPE))()
        try:
            for ip in self.ips:
                sglib.ShellCmd('ipset -A %s %s' % (tmpname, ip))()

            try:
                sglib.ShellCmd('ipset -N %s %s' % (self.name, self.IPSET_TYPE))()
                cherrypy.log('created new ipset: %s' % self.name)
            except Exception:
                # Any failure is assumed to mean "already exists". If the set
                # truly could not be created, the swap below raises instead.
                cherrypy.log('%s already exists, no need to create new' % self.name)

            # Exchange contents: self.name receives the new members and
            # tmpname ends up holding the old ones, which are discarded below.
            sglib.ShellCmd('ipset -W %s %s' % (tmpname, self.name))()
        finally:
            # Nested so the temporary set is destroyed even if flushing fails.
            try:
                sglib.ShellCmd('ipset -F %s' % tmpname)()
            finally:
                sglib.ShellCmd('ipset -X %s' % tmpname)()

    @staticmethod
    def destroy_sets(sets_to_keep):
        """Destroy every ipset listed by ``ipset list`` whose name is not in ``sets_to_keep``.

        NOTE(review): this also removes ipsets that were not created by this
        agent. Confirm nothing else on the host relies on ipsets.
        """
        keep = set(sets_to_keep)
        listing = sglib.ShellCmd('ipset list')()
        for line in listing.split('\n'):
            if 'Name:' not in line:
                continue
            set_name = line.split(':', 1)[1].strip()
            if set_name not in keep:
                sglib.ShellCmd('ipset destroy %s' % set_name)()
                cherrypy.log('destroyed unused ipset: %s' % set_name)
        
class SGAgent(object):
    """cherrypy application that programs per-VM iptables chains and ipsets.

    Every request goes through ``index``. It calls the method named by the
    request's ``command`` header, passing the parsed ``sglib.Request``.
    """

    # Methods ``index`` may dispatch to. Restricting this stops a request
    # header from reaching other attributes (``start``, ``stop``, private
    # helpers, dunder methods, ...).
    _COMMANDS = frozenset(['set_rules', 'echo'])

    def __init__(self):
        pass

    def _self_list(self, obj):
        """Return ``obj`` if it is already a list, otherwise ``[obj]``.

        Used on xmlobject nodes that may occur once or several times.
        Presumably xmlobject yields a bare node for a single element and a
        list for repeated ones — NOTE(review): confirm against xmlobject.
        """
        if isinstance(obj, types.ListType):
            return obj
        return [obj]

    def _parse_rules(self, rules):
        """Convert rule XML node(s) into a list of ``SGRule`` objects."""
        parsed = []
        for node in self._self_list(rules):
            rule = SGRule()
            rule.protocol = node.protocol.text_
            rule.start_port = node.startPort.text_
            rule.end_port = node.endPort.text_
            if hasattr(node, 'ip'):
                for ip in self._self_list(node.ip):
                    rule.allowed_ips.append(ip.text_)
            parsed.append(rule)
        return parsed

    @staticmethod
    def _create_chain(name):
        """Ensure iptables chain ``name`` exists and is empty."""
        try:
            # Flushing an existing chain empties it. If the flush fails, the
            # chain is assumed to be missing and is created.
            sglib.ShellCmd('iptables -F %s' % name)()
        except Exception:
            sglib.ShellCmd('iptables -N %s' % name)()

    @staticmethod
    def _icmp_type(rule):
        """Return the ``--icmp-type`` argument for an ICMP rule.

        ``start_port`` and ``end_port`` are joined as ``type/code``. A
        ``start_port`` of ``"-1"`` selects any ICMP type.
        """
        if rule.start_port == "-1":
            return "any"
        return "/".join([rule.start_port, rule.end_port])

    def _apply_rules(self, rules, chainname, direction, action, current_set_names):
        """Rebuild chain ``chainname`` with one or two iptables rules per ``SGRule``.

        Args:
            rules: list of SGRule to apply to this chain.
            chainname: iptables chain to (re)build.
            direction: ``'src'`` or ``'dst'``; the address the ipset match
                is checked against.
            action: iptables target (``-j``) for matching traffic.
            current_set_names: list that receives the names of ipsets
                created here, so ``IPSet.destroy_sets`` keeps them.
        """
        self._create_chain(chainname)
        for r in rules:
            # "Anywhere" (0.0.0.0/0) gets its own rule without an ipset match.
            # The remaining addresses are matched through an ipset.
            allow_any = '0.0.0.0/0' in r.allowed_ips
            if allow_any:
                r.allowed_ips.remove('0.0.0.0/0')

            if r.allowed_ips:
                setname = '_'.join([chainname, r.protocol, r.start_port, r.end_port])
                IPSet(setname, r.allowed_ips).create()
                current_set_names.append(setname)

                if r.protocol == 'all':
                    cmd = ['iptables -I', chainname, '-m state --state NEW -m set --set', setname, direction, '-j', action]
                elif r.protocol != 'icmp':
                    port_range = ":".join([r.start_port, r.end_port])
                    cmd = ['iptables', '-I', chainname, '-p', r.protocol, '-m', r.protocol, '--dport', port_range, '-m state --state NEW -m set --set', setname, direction, '-j', action]
                else:
                    cmd = ['iptables', '-I', chainname, '-p', 'icmp', '--icmp-type', self._icmp_type(r), '-m set --set', setname, direction, '-j', action]
                sglib.ShellCmd(' '.join(cmd))()

            if allow_any:
                if r.protocol == 'all':
                    cmd = ['iptables', '-I', chainname, '-m', 'state', '--state', 'NEW', '-j', action]
                elif r.protocol != 'icmp':
                    port_range = ":".join([r.start_port, r.end_port])
                    cmd = ['iptables', '-I', chainname, '-p', r.protocol, '-m', r.protocol, '--dport', port_range, '-m', 'state', '--state', 'NEW', '-j', action]
                else:
                    cmd = ['iptables', '-I', chainname, '-p', 'icmp', '--icmp-type', self._icmp_type(r), '-j', action]
                sglib.ShellCmd(' '.join(cmd))()

    def set_rules(self, req):
        """Rebuild the ingress and egress chains for the VM described in ``req.body``.

        The body is XML with ``vmName`` and optional ``ingressRules`` and
        ``egressRules`` elements.

        - Ingress rules go into chain ``<vmName>-in`` (ACCEPT on match,
          DROP otherwise).
        - Egress rules go into ``<vmName>-eg`` (RETURN on match).
        - Afterwards, every ipset not created by this call is destroyed.

        NOTE(review): request values are interpolated into command strings
        without validation. Confirm sglib.ShellCmd does not run them through
        a shell, or validate them first.
        """
        doc = xmlobject.loads(req.body)
        vm_name = doc.vmName.text_
        # Not used below yet. They are still read, so a request lacking them
        # presumably keeps failing as before.
        vm_id = doc.vmId.text_
        vm_ip = doc.vmIp.text_
        vm_mac = doc.vmMac.text_
        sig = doc.signature.text_
        seq = doc.sequenceNumber.text_

        i_rules = self._parse_rules(doc.ingressRules) if hasattr(doc, 'ingressRules') else []
        e_rules = self._parse_rules(doc.egressRules) if hasattr(doc, 'egressRules') else []

        current_sets = []
        i_chain_name = vm_name + '-in'
        self._apply_rules(i_rules, i_chain_name, 'src', 'ACCEPT', current_sets)
        e_chain_name = vm_name + '-eg'
        self._apply_rules(e_rules, e_chain_name, 'dst', 'RETURN', current_sets)

        # Terminal egress rule. Egress rules RETURN matching traffic, so they
        # only have an effect if everything else is dropped. With no egress
        # rules, egress is left open. NOTE(review): assumed to match
        # CloudStack's default-allow egress policy — confirm.
        if e_rules:
            sglib.ShellCmd('iptables -A %s -j DROP' % e_chain_name)()
        else:
            sglib.ShellCmd('iptables -A %s -j RETURN' % e_chain_name)()

        sglib.ShellCmd('iptables -A %s -j DROP' % i_chain_name)()
        IPSet.destroy_sets(current_sets)

    def echo(self, req):
        """Liveness probe: only logs a message."""
        cherrypy.log("echo: I am alive")

    def index(self):
        """cherrypy entry point: dispatch to the method named in the ``command`` header.

        Raises:
            ValueError: if the command is not one of ``_COMMANDS``.
        """
        req = sglib.Request.from_cherrypy_request(cherrypy.request)
        cmd_name = req.headers['command']

        if cmd_name not in self._COMMANDS or not hasattr(self, cmd_name):
            raise ValueError("SecurityGroupAgent doesn't have a method called '%s'" % cmd_name)
        method = getattr(self, cmd_name)

        return method(req)
    index.exposed = True

    @staticmethod
    def start():
        """Configure cherrypy logging and listen address, then serve an SGAgent (blocks)."""
        cherrypy.log.access_file = '/var/log/cs-securitygroup.log'
        cherrypy.log.error_file = '/var/log/cs-securitygroup.log'
        cherrypy.server.socket_host = '0.0.0.0'
        cherrypy.server.socket_port = 9988
        cherrypy.quickstart(SGAgent())

    @staticmethod
    def stop():
        """Shut down the cherrypy engine."""
        cherrypy.engine.exit()

PID_FILE = '/var/run/cssgagent.pid'


class SGAgentDaemon(sglib.Daemon):
    """sglib.Daemon that runs an SGAgent, using PID_FILE as its pid file."""

    def __init__(self):
        super(SGAgentDaemon, self).__init__(PID_FILE)
        self.is_stopped = False
        self.agent = SGAgent()
        sglib.Daemon.register_atexit_hook(self._do_stop)

    def _do_stop(self):
        """Stop the agent at most once; also registered as the atexit hook."""
        if self.is_stopped:
            return
        self.is_stopped = True
        self.agent.stop()

    def run(self):
        """Daemon body: start the agent."""
        self.agent.start()

    def stop(self):
        """Stop the agent, then let sglib.Daemon perform its own stop handling."""
        # Route through _do_stop so the atexit hook registered in __init__
        # does not stop the agent a second time.
        self._do_stop()
        super(SGAgentDaemon, self).stop()

def main():
    """Command-line entry point: start, stop or restart the agent daemon.

    Expects exactly one argument: ``start``, ``stop`` or ``restart``.
    Otherwise it prints the usage line and exits with status 1. After
    running the requested daemon command it exits with status 0.
    """
    usage = 'usage: python -c "from security_group_agent import cs_sg_agent; cs_sg_agent.main()" start|stop|restart'
    valid_commands = ('start', 'stop', 'restart')
    args = sys.argv[1:]
    if len(args) != 1 or args[0] not in valid_commands:
        print(usage)
        sys.exit(1)

    action = args[0]
    daemon = SGAgentDaemon()
    # Each command name matches the daemon method that implements it.
    getattr(daemon, action)()

    sys.exit(0)
